{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.8\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.0.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Zoo -- Getting Gradients of an Intermediate Variable in PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook illustrates how we can fetch the intermediate gradients of a function that is composed of multiple inputs and multiple computation steps in PyTorch. Note that gradient is simply a vector listing the derivatives of a function with respect\n",
    "to each argument of the function. So, strictly speaking, we are discussing how to obtain the partial derivatives here."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Assume we have this simple toy graph:\n",
    "    \n",
    "![](../images/manual-gradients/graph_1.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we provide the following values to b, x, and w; the red numbers indicate the intermediate values of the computation and the end result:\n",
    "\n",
    "![](../images/manual-gradients/graph_2.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, the next image shows the partial derivatives of the output node, a, with respect to the input nodes (b, x, and w) as well as all the intermediate partial derivatives:\n",
    "\n",
    "\n",
    "![](../images/manual-gradients/graph_3.png)"
   ]
  },
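  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of where these numbers come from, assume the graph computes $u = x \\cdot w$, $v = u + b$, and $a = \\mathrm{ReLU}(v)$ with $x = 3$, $w = 2$, and $b = 1$ (the values used in the code cells of this notebook). Under that assumption, the chain rule gives:\n",
    "\n",
    "$$\\frac{\\partial a}{\\partial v} = 1 \\quad (\\text{since } v = 3 \\cdot 2 + 1 = 7 > 0), \\qquad \\frac{\\partial a}{\\partial u} = \\frac{\\partial a}{\\partial v} \\cdot \\frac{\\partial v}{\\partial u} = 1 \\cdot 1 = 1, \\qquad \\frac{\\partial a}{\\partial b} = \\frac{\\partial a}{\\partial v} \\cdot \\frac{\\partial v}{\\partial b} = 1 \\cdot 1 = 1$$\n",
    "\n",
    "$$\\frac{\\partial a}{\\partial x} = \\frac{\\partial a}{\\partial u} \\cdot \\frac{\\partial u}{\\partial x} = 1 \\cdot w = 2, \\qquad \\frac{\\partial a}{\\partial w} = \\frac{\\partial a}{\\partial u} \\cdot \\frac{\\partial u}{\\partial w} = 1 \\cdot x = 3$$\n",
    "\n",
    "These values agree with the printed outputs of the code cells in this notebook."
   ]
  },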
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(The images were taken from my PyData Talk in August 2017, for more information of how to arrive at these derivatives, please see the talk/slides at https://github.com/rasbt/pydata-annarbor2017-dl-tutorial; also, I put up a little calculus and differentiation primer if helpful: https://sebastianraschka.com/pdf/books/dlb/appendix_d_calculus.pdf)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For instance, if we are interested in obtaining the partial derivative of the output a with respect to each of the input and intermediate nodes, we could do the following in TensorFlow, where `d_a_b` denotes \"partial derivative of a with respect to b\" and so forth:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[2.0], [3.0], [1.0], [1.0], [1.0]]\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "g = tf.Graph()\n",
    "with g.as_default() as g:\n",
    "    \n",
    "    x = tf.placeholder(dtype=tf.float32, shape=None, name='x')\n",
    "    w = tf.Variable(initial_value=2, dtype=tf.float32, name='w')\n",
    "    b = tf.Variable(initial_value=1, dtype=tf.float32, name='b')\n",
    "    \n",
    "    u = x * w\n",
    "    v = u + b\n",
    "    a = tf.nn.relu(v)\n",
    "    \n",
    "    d_a_x = tf.gradients(a, x)\n",
    "    d_a_w = tf.gradients(a, w)\n",
    "    d_a_b = tf.gradients(a, b)\n",
    "    d_a_u = tf.gradients(a, u)\n",
    "    d_a_v = tf.gradients(a, v)\n",
    "\n",
    "\n",
    "with tf.Session(graph=g) as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    grads = sess.run([d_a_x, d_a_w, d_a_b, d_a_u, d_a_v], feed_dict={'x:0': 3})\n",
    "\n",
    "print(grads)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Intermediate Gradients in PyTorch via autograd's `grad`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In PyTorch, there are multiple ways to compute partial derivatives or gradients. If the goal is to just compute partial derivatives, the most straight-forward way would be using autograd's `grad` function. By default, the `retain_graph` parameter of the `grad` function is set to `False`, which will free the graph after computing the partial derivative. Thus, if we want to obtain multiple partial derivatives, we need to set `retain_graph=True`. Note that this is a very inefficient solution though, as multiple passes over the graph are being made where intermediate results are being recalculated:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "d_a_x: (tensor([2.]),)\n",
      "d_a_w: (tensor([3.]),)\n",
      "d_a_b: (tensor([1.]),)\n",
      "d_a_u: (tensor([1.]),)\n",
      "d_a_v: (tensor([1.]),)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import grad\n",
    "\n",
    "\n",
    "x = torch.tensor([3.], requires_grad=True)\n",
    "w = torch.tensor([2.], requires_grad=True)\n",
    "b = torch.tensor([1.], requires_grad=True)\n",
    "\n",
    "u = x * w\n",
    "v = u + b\n",
    "a = F.relu(v)\n",
    "\n",
    "d_a_b = grad(a, b, retain_graph=True)\n",
    "d_a_u = grad(a, u, retain_graph=True)\n",
    "d_a_v = grad(a, v, retain_graph=True)\n",
    "d_a_w = grad(a, w, retain_graph=True)\n",
    "d_a_x = grad(a, x)\n",
    "    \n",
    "\n",
    "for name, grad in zip(\"xwbuv\", (d_a_x, d_a_w, d_a_b, d_a_u, d_a_v)):\n",
    "    print('d_a_%s:' % name, grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As suggested by Adam Paszke, this can be made rewritten in a more efficient manner by passing a tuple to the `grad` function so that it can reuse intermediate results and only require one pass over the graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "d_a_x: tensor([2.])\n",
      "d_a_w: tensor([3.])\n",
      "d_a_b: tensor([1.])\n",
      "d_a_u: tensor([1.])\n",
      "d_a_v: tensor([1.])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import grad\n",
    "\n",
    "\n",
    "x = torch.tensor([3.], requires_grad=True)\n",
    "w = torch.tensor([2.], requires_grad=True)\n",
    "b = torch.tensor([1.], requires_grad=True)\n",
    "\n",
    "u = x * w\n",
    "v = u + b\n",
    "a = F.relu(v)\n",
    "\n",
    "partial_derivatives = grad(a, (x, w, b, u, v))\n",
    "\n",
    "for name, grad in zip(\"xwbuv\", (partial_derivatives)):\n",
    "    print('d_a_%s:' % name, grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Intermediate Gradients in PyTorch via `retain_grad`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In PyTorch, we most often use the `backward()` method on an output variable to compute its partial derivative (or gradient) with respect to its inputs (typically, the weights and bias units of a neural network). By default, PyTorch only stores the gradients of the leaf variables (e.g., the weights and biases) via their `grad` attribute to save memory. So, if we are interested in the intermediate results in a computational graph, we can use the `retain_grad` method to store gradients of non-leaf variables as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "d_a_x: tensor([2.])\n",
      "d_a_w: tensor([3.])\n",
      "d_a_b: tensor([1.])\n",
      "d_a_u: tensor([1.])\n",
      "d_a_v: tensor([1.])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "\n",
    "\n",
    "x = torch.tensor([3.], requires_grad=True)\n",
    "w = torch.tensor([2.], requires_grad=True)\n",
    "b = torch.tensor([1.], requires_grad=True)\n",
    "\n",
    "u = x * w\n",
    "v = u + b\n",
    "a = F.relu(v)\n",
    "\n",
    "u.retain_grad()\n",
    "v.retain_grad()\n",
    "\n",
    "a.backward()\n",
    "\n",
    "for name, var in zip(\"xwbuv\", (x, w, b, u, v)):\n",
    "    print('d_a_%s:' % name, var.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Intermediate Gradients in PyTorch Using Hooks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, and this is a not-recommended workaround, we can use hooks to obtain intermediate gradients. While the two other approaches explained above should be preferred, this approach highlights the use of hooks, which may come in handy in certain situations.\n",
    "\n",
    "> The hook will be called every time a gradient with respect to the variable is computed.  (http://pytorch.org/docs/master/autograd.html#torch.autograd.Variable.register_hook)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Based on the suggestion by Adam Paszke (https://discuss.pytorch.org/t/why-cant-i-see-grad-of-an-intermediate-variable/94/7?u=rasbt), we can use these hooks in a combintation with a little helper function, `save_grad` and a `hook` closure writing the partial derivatives or gradients to a global variable `grads`. So, if we invoke the `backward` method on the output node `a`, all the intermediate results will be collected in `grads`, as illustrated below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'d_a_v': tensor([1.]),\n",
       " 'd_a_b': tensor([1.]),\n",
       " 'd_a_u': tensor([1.]),\n",
       " 'd_a_x': tensor([2.]),\n",
       " 'd_a_w': tensor([3.])}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "grads = {}\n",
    "def save_grad(name):\n",
    "    def hook(grad):\n",
    "        grads[name] = grad\n",
    "    return hook\n",
    "\n",
    "\n",
    "x = torch.tensor([3.], requires_grad=True)\n",
    "w = torch.tensor([2.], requires_grad=True)\n",
    "b = torch.tensor([1.], requires_grad=True)\n",
    "\n",
    "u = x * w\n",
    "v = u + b\n",
    "\n",
    "x.register_hook(save_grad('d_a_x'))\n",
    "w.register_hook(save_grad('d_a_w'))\n",
    "b.register_hook(save_grad('d_a_b'))\n",
    "u.register_hook(save_grad('d_a_u'))\n",
    "v.register_hook(save_grad('d_a_v'))\n",
    "\n",
    "a = F.relu(v)\n",
    "\n",
    "a.backward()\n",
    "\n",
    "grads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensorflow  1.12.0\n",
      "torch       1.0.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
